function model_output=gfop_ces_ge_solver(V_ik,D_i,lambda_ink,beta_nk,gamma,theta,rho,t_ink,s_ink,K,N,opts)
%GFOP_CES_GE_SOLVER Solve for the gains from optimal trade and industrial
%policy, starting from an initial equilibrium that is assumed to feature no
%trade or industrial policy. Endogenous variables are "hat" changes
%relative to that initial equilibrium.
%
%Indexing convention: i = exporter/origin (dim 1), n = importer (dim 2),
%k = industry (dim 3). Shapes below are those implied by the repmat/sum
%calls in this function.
%   V_ik         N x 1 x K  baseline labor by country and industry
%                           (w_i.*V_i is labor income, V_i = sum over k)
%   D_i          N x 1      added to income (presumably the trade deficit)
%   lambda_ink   N x N x K  baseline trade shares (summed over i below)
%   beta_nk      1 x N x K  baseline industry expenditure shares
%   gamma        N x 1 x K  exponent on lhat in the unit-cost term
%                           (exporter-industry specific)
%   theta        1 x 1 x K  industry trade elasticity
%   rho          scalar     cross-industry substitution elasticity; the
%                           price index uses 1/(1-rho), so rho must not be 1
%   t_ink,s_ink  N x N x K  policy wedges (enter as (1-t) and (1+s))
%   K, N         number of industries and countries
%   opts         optional struct overriding solver settings (field: default)
%                  max_w_iter 100, max_bhat_iter 10, max_lhat_iter 10,
%                  max_e_w 1e-4, max_e_bhat 1e-4, max_e_lhat 1e-4,
%                  ss_w 0.1, ss_bhat 0.25, ss_lhat 0.25, verbose false
%
%Returns a struct with fields welfare_i, ri_i, lambda_hat, w_i, R_i, lhat,
%P_i and the convergence diagnostics converged, iter_w, Z_w, Z_bhat, Z_lhat.
%Issues a 'gfop_ces_ge_solver:noConvergence' warning if the wage loop does
%not reach max_e_w, and a 'gfop_ces_ge_solver:unknownOption' error for
%unrecognized opts fields.

%Solver settings (defaults used when opts is omitted)
cfg=struct('max_w_iter',100,'max_bhat_iter',10,'max_lhat_iter',10, ...
    'max_e_w',1e-4,'max_e_bhat',1e-4,'max_e_lhat',1e-4, ...
    'ss_w',0.1,'ss_bhat',0.25,'ss_lhat',0.25,'verbose',false);
if nargin>=12 && ~isempty(opts)
    opt_names=fieldnames(opts);
    for f=1:numel(opt_names)
        if ~isfield(cfg,opt_names{f})
            error('gfop_ces_ge_solver:unknownOption','Unknown option "%s".',opt_names{f});
        end
        cfg.(opt_names{f})=opts.(opt_names{f});
    end
end

V_i=sum(V_ik,3);

%Precalcs: expand parameters to N x N x K once
theta_ink=repmat(theta,[N N 1]);
gamma_ink=repmat(gamma,[1 N 1]); %note that gamma here is exporter-industry-specific
tempsub_ink=t_ink.*(1+s_ink)-s_ink; %wedge entering the revenue matrix M
beta_nnk=repmat(beta_nk,[N 1 1]);

%NOTE: a nested function shares every variable whose name also appears in
%this parent workspace. Helper locals below deliberately use names that the
%parent does not, so calling a helper never overwrites parent state.

    %Unnormalized lambda_hat terms (N x N x K) and their baseline-share
    %weighted sum over exporters i (1 x N x K)
    function [num_ink,den_nk]=share_terms(w_vec,l_ik)
        w_rep=repmat(w_vec,[1 N K]);
        l_rep=repmat(l_ik,[1 N 1]);
        num_ink=(1-t_ink).*(w_rep./((1-t_ink).*(1+s_ink).*(l_rep.^gamma_ink))).^(-theta_ink);
        den_nk=sum(num_ink.*lambda_ink,1);
    end

    %lambda_hat (N x N x K) given wages and employment changes
    function lam_hat=lambda_hat_calc(w_vec,l_ik)
        [num_ink,den_nk]=share_terms(w_vec,l_ik);
        lam_hat=num_ink./repmat(den_nk,[N 1 1]);
    end

    %Revenue R (N x 1) solving R = M*(w.*V + D + R), where
    %M(i,n) = sum over k of lambda_hat.*lambda.*tempsub.*bhat.*beta
    function R=revenue_calc(w_vec,lam_hat,bhat)
        M_in=sum(lam_hat.*lambda_ink.*tempsub_ink.*repmat(bhat,[N 1 1]).*beta_nnk,3);
        R=(eye(N)-M_in)\(M_in*(w_vec.*V_i+D_i));
    end

    %Updated lhat (N x 1 x K): producer revenue divided by the wage bill
    %w.*V_ik at the current guess; also returns the revenue R it used
    function [l_new,R]=lhat_calc(w_vec,l_ik,bhat)
        lam_hat=lambda_hat_calc(w_vec,l_ik);
        R=revenue_calc(w_vec,lam_hat,bhat);
        income_n=w_vec.*V_i+R+D_i; %N x 1 income of each importer n
        spend_ink=repmat(income_n',[N 1 K]).*repmat(bhat,[N 1 1]).*beta_nnk;
        sales_ik=sum(((1+s_ink).*(1-t_ink)).*(lam_hat.*lambda_ink).*spend_ink,2);
        l_new=sales_ik./(repmat(w_vec,[1 1 K]).*V_ik);
    end

    %Industry price term (1 x N x K); the aggregate price index is
    %(sum over k of beta_nk times this term)^(1/(1-rho))
    function pk_nk=industry_price_term(w_vec,l_ik)
        [~,den_nk]=share_terms(w_vec,l_ik);
        pk_nk=den_nk.^((1-rho)./(-theta_ink(1,:,:)));
    end

    %bhat (1 x N x K): industry price term relative to its beta-weighted
    %average over industries
    function bhat=bhat_calc(w_vec,l_ik)
        pk_nk=industry_price_term(w_vec,l_ik);
        bhat=pk_nk./repmat(sum(pk_nk.*beta_nk,3),[1 1 K]);
    end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%Algorithm to compute equilibrium: damped fixed-point iteration, with lhat
%innermost, bhat in the middle and wages outermost

%Initial values (no change from the initial equilibrium)
w_i=ones(N,1);
lhat_ik=ones(N,1,K);
bhat_nk=ones(1,N,K);

Z_w=Inf;
Z_bhat=Inf;
Z_lhat=Inf;
converged=false;
for iter_w=1:cfg.max_w_iter
    for iter_b=1:cfg.max_bhat_iter
        for iter_l=1:cfg.max_lhat_iter
            lhat_new=lhat_calc(w_i,lhat_ik,bhat_nk);
            step_lhat=lhat_new-lhat_ik;
            Z_lhat=sum(abs(step_lhat(:))); %total over all i,k (scalar)
            lhat_ik=lhat_ik+cfg.ss_lhat*step_lhat;
            if Z_lhat<cfg.max_e_lhat
                break;
            end
        end
        bhat_new=bhat_calc(w_i,lhat_ik);
        step_bhat=bhat_new-bhat_nk;
        Z_bhat=sum(abs(step_bhat(:))); %total over all n,k (scalar)
        bhat_nk=bhat_nk+cfg.ss_bhat*step_bhat;
        if Z_bhat<cfg.max_e_bhat
            break;
        end
    end
    %Update wages based on excess demand for labor
    excess_lab_dem=sum(lhat_ik.*V_ik,3)./V_i-1;
    Z_w=sum(abs(excess_lab_dem));
    w_i=w_i+cfg.ss_w*(w_i.*excess_lab_dem);
    %Renormalize wages so that sum(w_i.*V_i)=1. This agrees with the w_i=1
    %starting point only if sum(V_i)=1 -- TODO confirm inputs are normalized.
    w_i=w_i/sum(w_i.*V_i);
    if cfg.verbose
        fprintf('gfop_ces_ge_solver: iter %d, Z_w=%g, Z_bhat=%g, Z_lhat=%g\n',iter_w,Z_w,Z_bhat,Z_lhat);
    end
    if Z_w<cfg.max_e_w
        converged=true;
        break;
    end
end
if ~converged
    warning('gfop_ces_ge_solver:noConvergence', ...
        'Wage iteration did not converge after %d iterations (Z_w = %g).',cfg.max_w_iter,Z_w);
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%Model output, all evaluated at the final (renormalized) wages so that
%revenue, prices and welfare are mutually consistent
lambda_hat=lambda_hat_calc(w_i,lhat_ik);
R_i=revenue_calc(w_i,lambda_hat,bhat_nk);

%Final aggregate price index change (N x 1)
pk_final=industry_price_term(w_i,lhat_ik);
P_i=(sum(pk_final.*beta_nk,3).^(1/(1-rho)))';

%Hat change in welfare: income change deflated by P_i. Baseline income is
%V_i+D_i since the initial equilibrium has no policy revenue.
welfare_i=((w_i.*V_i+R_i+D_i)./(V_i+D_i))./P_i;
ri_i=((w_i.*V_i+R_i)./V_i)./P_i;

model_output.welfare_i=welfare_i;
model_output.ri_i=ri_i;
model_output.lambda_hat=lambda_hat;
model_output.w_i=w_i;
model_output.R_i=R_i;
model_output.lhat=lhat_ik;
model_output.P_i=P_i;
model_output.converged=converged;
model_output.iter_w=iter_w;
model_output.Z_w=Z_w;
model_output.Z_bhat=Z_bhat;
model_output.Z_lhat=Z_lhat;
end
